package com.hys.app.service.oauth2.impl;

import cn.hutool.core.codec.Base64;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.digest.HMac;
import cn.hutool.crypto.digest.HmacAlgorithm;
import com.hys.app.framework.database.WebPage;
import com.hys.app.framework.exception.ServiceException;
import com.hys.app.framework.util.BeanUtils;
import com.hys.app.framework.util.StrUtils;
import com.hys.app.framework.util.StringUtil;
import com.hys.app.mapper.oauth2.OAuth2ClientMapper;
import com.hys.app.model.oauth2.constant.OAuth2RedisKey;
import com.hys.app.model.oauth2.dos.OAuth2ClientDO;
import com.hys.app.model.oauth2.dto.OAuth2ClientPageReqVO;
import com.hys.app.model.oauth2.dto.OAuth2ClientSaveReqVO;
import com.hys.app.service.oauth2.OAuth2ClientManager;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Collection;

/**
 * OAuth2.0 client service implementation.
 * Functionally comparable to Spring Security OAuth's {@code JdbcClientDetailsService}: it provides
 * CRUD operations on OAuth2 clients plus request-time validation of a client.
 *
 * @author 张崧
 * @since 2024-02-20
 */
@Service
@Validated
@Slf4j
public class OAuth2ClientManagerImpl implements OAuth2ClientManager {

    /**
     * Maximum allowed difference, in milliseconds, between a request timestamp and the server clock
     * (in either direction).
     */
    private static final long TIMESTAMP_TOLERANCE_MILLIS = 60 * 1000L;

    @Autowired
    private OAuth2ClientMapper oauth2ClientMapper;

    /**
     * Lazily injected reference to this bean's proxy. Calls to {@link #getOAuth2ClientFromCache(String)}
     * from inside this class go through it so that the {@code @Cacheable} advice is applied
     * (a plain {@code this.} call would bypass the Spring proxy).
     */
    @Autowired
    @Lazy
    private OAuth2ClientManager self;

    /**
     * Creates a new OAuth2 client.
     * No cache eviction is needed: {@code getOAuth2ClientFromCache} never caches null results.
     *
     * @param createReqVO client data to persist
     * @return the id of the inserted client, as read back from the entity after insert
     * @throws ServiceException if another client already uses the same clientId
     */
    @Override
    public Long createOAuth2Client(OAuth2ClientSaveReqVO createReqVO) {
        validateClientIdExists(null, createReqVO.getClientId());
        OAuth2ClientDO client = BeanUtils.toBean(createReqVO, OAuth2ClientDO.class);
        oauth2ClientMapper.insert(client);
        return client.getId();
    }

    /**
     * Updates an existing OAuth2 client and clears the whole client cache.
     *
     * @param updateReqVO new client data; its id selects the row to update
     * @throws ServiceException if the client does not exist or its clientId is used by another client
     */
    @Override
    @CacheEvict(cacheNames = OAuth2RedisKey.OAUTH_CLIENT, allEntries = true)
    public void updateOAuth2Client(OAuth2ClientSaveReqVO updateReqVO) {
        validateOAuth2ClientExists(updateReqVO.getId());
        // The clientId must not be taken by a different client
        validateClientIdExists(updateReqVO.getId(), updateReqVO.getClientId());

        OAuth2ClientDO updateObj = BeanUtils.toBean(updateReqVO, OAuth2ClientDO.class);
        oauth2ClientMapper.updateById(updateObj);
    }

    /**
     * Deletes an OAuth2 client and clears the whole client cache.
     *
     * @param id database id of the client
     * @throws ServiceException if the client does not exist
     */
    @Override
    @CacheEvict(cacheNames = OAuth2RedisKey.OAUTH_CLIENT, allEntries = true)
    public void deleteOAuth2Client(Long id) {
        validateOAuth2ClientExists(id);
        oauth2ClientMapper.deleteById(id);
    }

    /**
     * Ensures {@code clientId} is not already used by a client other than the one identified by {@code id}.
     *
     * @param id       id of the client being updated, or {@code null} when creating a new client
     * @param clientId the clientId to check
     * @throws ServiceException if the clientId is already in use by another client
     */
    private void validateClientIdExists(Long id, String clientId) {
        OAuth2ClientDO client = oauth2ClientMapper.selectByClientId(clientId);
        if (client == null) {
            return;
        }
        // On create (id == null) any existing row is a conflict; on update the row must be the same client.
        if (id == null || !id.equals(client.getId())) {
            throw new ServiceException("OAuth2 客户端id已存在");
        }
    }

    /**
     * Loads a client by database id directly from the mapper (no cache).
     *
     * @param id database id
     * @return the client, or whatever the mapper returns when no row matches
     */
    @Override
    public OAuth2ClientDO getOAuth2Client(Long id) {
        return oauth2ClientMapper.selectById(id);
    }

    /**
     * Loads a client by its clientId, caching non-null results under {@code OAuth2RedisKey.OAUTH_CLIENT}.
     *
     * @param clientId the client's clientId; used as the cache key
     * @return the client, or {@code null} if none matches (null results are not cached)
     */
    @Override
    @Cacheable(cacheNames = OAuth2RedisKey.OAUTH_CLIENT, key = "#clientId", unless = "#result == null")
    public OAuth2ClientDO getOAuth2ClientFromCache(String clientId) {
        return oauth2ClientMapper.selectByClientId(clientId);
    }

    /**
     * Returns one page of clients matching the query.
     *
     * @param pageReqVO paging and filter parameters
     * @return the requested page
     */
    @Override
    public WebPage<OAuth2ClientDO> getOAuth2ClientPage(OAuth2ClientPageReqVO pageReqVO) {
        return oauth2ClientMapper.selectPage(pageReqVO);
    }

    /**
     * Validates an incoming request against the (cached) client configuration.
     * Each optional check runs only when its argument is non-empty.
     *
     * @param clientId            the client's clientId (required)
     * @param sign                optional request signature; see {@link #createSign(String, String, String)}
     * @param timestamp           optional request time in epoch milliseconds, as a string
     * @param authorizedGrantType optional grant type that must be enabled for the client
     * @param scopes              optional scopes that must all be granted to the client
     * @param redirectUri         optional redirect URI that must start with one of the client's registered URIs
     * @return the validated client
     * @throws ServiceException if any check fails
     */
    @Override
    public OAuth2ClientDO validOAuthClientFromCache(String clientId, String sign, String timestamp, String authorizedGrantType,
                                                    Collection<String> scopes, String redirectUri) {

        if (StringUtil.notEmpty(timestamp)) {
            validateTimestamp(clientId, timestamp);
        }

        // Verify the client exists and is enabled.
        // An empty clientId can never match, and a null one would be rejected as a cache key.
        if (StrUtil.isEmpty(clientId)) {
            throw new ServiceException("OAuth2客户端不存在");
        }
        OAuth2ClientDO client = self.getOAuth2ClientFromCache(clientId);
        if (client == null) {
            throw new ServiceException("OAuth2客户端不存在");
        }
        // A status of 1 is treated as "disabled"
        if (client.getStatus() == 1) {
            throw new ServiceException("OAuth2客户端已禁用");
        }

        // Verify the signature made with the client secret.
        // NOTE(review): when sign is empty the secret is not checked at all, and when timestamp is empty the
        // signed content has no time component, so such a signature never expires. Confirm both are intended.
        if (StrUtil.isNotEmpty(sign) && !isSignValid(client, timestamp, sign)) {
            throw new ServiceException("签名验证失败");
        }
        // Verify the grant type
        if (StrUtil.isNotEmpty(authorizedGrantType) && !CollUtil.contains(client.getAuthorizedGrantTypes(), authorizedGrantType)) {
            throw new ServiceException("不支持该授权类型");
        }
        // Verify the requested scopes are all granted
        if (CollUtil.isNotEmpty(scopes) && !CollUtil.containsAll(client.getScopes(), scopes)) {
            throw new ServiceException("授权范围过大");
        }
        // Verify the redirect URI
        if (StrUtil.isNotEmpty(redirectUri) && !StrUtils.startWithAny(redirectUri, client.getRedirectUris())) {
            throw new ServiceException(StrUtil.format("无效 redirect_uri: {}", redirectUri));
        }
        return client;
    }

    /**
     * Rejects a request whose timestamp is not a valid long, or whose value lies outside
     * {@code now ± TIMESTAMP_TOLERANCE_MILLIS}.
     *
     * @param clientId  used only for logging
     * @param timestamp request time in epoch milliseconds, as a string
     * @throws ServiceException if the timestamp is malformed or outside the allowed window
     */
    private static void validateTimestamp(String clientId, String timestamp) {
        long requestTime;
        try {
            requestTime = Long.parseLong(timestamp);
        } catch (NumberFormatException e) {
            log.warn("{}时间戳格式错误：{}", clientId, timestamp);
            throw new ServiceException("时间戳格式错误");
        }
        long now = System.currentTimeMillis();
        log.info("{}时间差：{}", clientId, now - requestTime);
        // Compare against the window bounds instead of Math.abs(now - requestTime). That subtraction can
        // overflow for extreme inputs, and Math.abs(Long.MIN_VALUE) is negative, which would pass the check.
        if (requestTime < now - TIMESTAMP_TOLERANCE_MILLIS || requestTime > now + TIMESTAMP_TOLERANCE_MILLIS) {
            log.warn("{}请求过期", clientId);
            throw new ServiceException("请求过期");
        }
    }

    /**
     * Checks a request signature against the one computed from the client's secret.
     *
     * @return {@code true} only if the client has a secret and the signatures match
     */
    private static boolean isSignValid(OAuth2ClientDO client, String timestamp, String sign) {
        String secret = client.getSecret();
        if (StrUtil.isEmpty(secret)) {
            // A client without a secret cannot produce a valid signature (an empty HMAC key is also rejected).
            return false;
        }
        String expected = createSign(client.getClientId(), secret, timestamp);
        // Constant-time comparison so response timing does not reveal how much of the signature matched.
        return MessageDigest.isEqual(expected.getBytes(StandardCharsets.UTF_8), sign.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Computes the request signature {@code Base64(HMAC-SHA256(secret, timestamp + clientId + timestamp))}.
     * The secret is encoded as UTF-8 so the key bytes do not depend on the platform default charset.
     */
    private static String createSign(String clientId, String secret, String timestamp) {
        String signContent = timestamp + clientId + timestamp;
        HMac mac = new HMac(HmacAlgorithm.HmacSHA256, secret.getBytes(StandardCharsets.UTF_8));
        return Base64.encode(mac.digest(signContent));
    }

    /**
     * Ensures a client with the given database id exists.
     *
     * @throws ServiceException if no client has that id
     */
    private void validateOAuth2ClientExists(Long id) {
        if (oauth2ClientMapper.selectById(id) == null) {
            throw new ServiceException("OAuth2 客户端不存在");
        }
    }

    /**
     * Developer helper that prints a current timestamp and the matching signature for manual API testing.
     * Usage: {@code OAuth2ClientManagerImpl <clientId> <clientSecret>}.
     * Credentials come from the command line so that no client secret is committed to source control.
     */
    public static void main(String[] args) {
        if (args == null || args.length < 2) {
            System.err.println("usage: OAuth2ClientManagerImpl <clientId> <clientSecret>");
            return;
        }
        String clientId = args[0];
        String clientSecret = args[1];
        long timestamp = System.currentTimeMillis();
        System.out.println("timestamp:" + timestamp);
        System.out.println("sign:" + createSign(clientId, clientSecret, Long.toString(timestamp)));
    }

}
